import random
from abc import abstractmethod
from enum import IntEnum
from typing import Iterable, Union, Optional, Any, SupportsFloat

import numpy as np
import gymnasium as gym
from gymnasium.core import ObsType, ActType
from gymnasium import spaces
from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Goal, Ball, Lava, Wall
from minigrid.manual_control import ManualControl
from minigrid.minigrid_env import MiniGridEnv
from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper


class EvaluationActions(IntEnum):
    """Integer ids of every action the evaluation environments understand.

    The ids are contiguous and ordered so that a prefix of the enum forms a
    smaller action set: the first 3 are turn/forward, the first 5 add
    sideways steps, the first 7 add diagonal steps, and the full set adds
    object interaction.
    """

    # --- basic navigation ---
    left = 0  # rotate 90 degrees to the left
    right = 1  # rotate 90 degrees to the right
    forward = 2  # step into the cell ahead

    # --- sideways translation (heading is kept) ---
    move_left = 3
    move_right = 4

    # --- diagonal movement (heading is kept) ---
    move_forward_left = 5
    move_forward_right = 6

    # --- object interaction ---
    pickup = 7  # pick up an object
    drop = 8  # drop the carried object
    toggle = 9  # toggle/activate an object


class ActionGridEnv(MiniGridEnv):
    def __init__(self, name="MiniGrid-Empty-8x8-a3", size=8,
                 agent_start_pos: Optional[Iterable[Union[int, Iterable]]] = None,
                 obj_pos: Optional[Iterable[Union[int, Iterable]]] = None, n_walls=0, n_lavas=0,
                 agent_start_dir: Optional[int] = None, max_steps: Optional[int] = None, seed=None, **kwargs):
        """Configure the environment, largely from tokens embedded in ``name``.

        Tokens recognised in ``name`` (only when it is a string):
            * ``"v3"`` / ``"v5"``   -> agent view size 3 / 5 (default 7)
            * ``"9x9"`` / ``"16x16"`` -> grid size 9 / 16 (overrides ``size``)
            * ``"Lava"`` / ``"Wall"`` -> 5 lava / 5 wall obstacles
              (overrides ``n_lavas`` / ``n_walls``)
            * suffix ``"a3"`` / ``"a5"`` / ``"a7"`` -> action space with the
              first 3 / 5 / 7 ``EvaluationActions``; otherwise all of them.

        Args:
            name: Environment name; also used verbatim as the mission text.
                If it is not a string, the default mission text is used and
                no name-derived option applies.
            size: Side length of the square grid.
            agent_start_pos: Fixed ``(x, y)`` or ``((x_min, x_max), (y_min, y_max))``
                range for the agent; ``None`` means random placement.
            obj_pos: Same format as ``agent_start_pos`` but for the goal.
            n_walls: Number of extra wall obstacles.
            n_lavas: Number of lava cells.
            agent_start_dir: Initial heading, or ``None``.
            max_steps: Episode step limit; defaults to 64 when ``None``.
            seed: ``None`` for a freshly random layout on every generation,
                or an int to tie generation to a fixed seed.
            **kwargs: Forwarded to ``MiniGridEnv.__init__``.
        """
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        self.obj_pos = obj_pos

        if isinstance(name, str):
            self.mission_name = name
            name_spec = name
        else:
            # No usable name: fall back to the default mission text and make
            # sure the token checks below see a string instead of crashing.
            self.mission_name = self._gen_mission()
            name_spec = ""

        self.state_dict = None  # snapshot of the environment state, if any

        # None: random on every generation; int: generation tied to this seed
        self.seed = seed

        agent_view_size = 7
        if "v3" in name_spec:
            agent_view_size = 3
        elif "v5" in name_spec:
            agent_view_size = 5

        if "9x9" in name_spec:
            size = 9
        elif "16x16" in name_spec:
            size = 16

        self.n_lavas = n_lavas
        self.n_walls = n_walls
        if "Lava" in name_spec:
            self.n_lavas = 5  # default number of lava cells
        if "Wall" in name_spec:
            self.n_walls = 5  # default number of wall obstacles

        # Reduce the obstacle count on small grids to lower the chance of an
        # unsolvable layout. The check uses the effective wall count so that
        # the name-derived default is capped as well.
        if self.n_walls > size / 2 + 1:
            self.n_walls = int(size / 2)

        # The mission space always yields the stored mission text
        mission_space = MissionSpace(mission_func=lambda: self.mission_name)

        super().__init__(
            mission_space=mission_space,
            grid_size=size,
            max_steps=64 if max_steps is None else max_steps,
            agent_view_size=agent_view_size,
            **kwargs,
        )
        self.actions = EvaluationActions

        # Action-space size is selected by the name suffix
        if name_spec.endswith("a3"):
            self.action_space = spaces.Discrete(3)
        elif name_spec.endswith("a5"):
            self.action_space = spaces.Discrete(5)
        elif name_spec.endswith("a7"):
            self.action_space = spaces.Discrete(7)
        else:
            self.action_space = spaces.Discrete(len(self.actions))

    def save_state(self):
        """Record the current agent/grid state on the env and return it.

        The snapshot is a dict with the keys ``agent_pos``, ``agent_dir``,
        ``carrying`` and ``grid``. The values are the live objects themselves
        (nothing is copied), so later in-place changes to e.g. the grid are
        visible through the snapshot as well.

        Returns:
            The snapshot dict, which is also stored in ``self.state_dict``.
        """
        self.state_dict = {
            "agent_pos": self.agent_pos,
            "agent_dir": self.agent_dir,
            "carrying": self.carrying,
            "grid": self.grid,
        }
        return self.state_dict

    def load_state(self, state_dict):
        """Install a snapshot so that it is used instead of random generation.

        This only stores the snapshot on the environment; the stored values
        are read when the grid is next generated.

        Args:
            state_dict: Mapping with the keys ``agent_pos``, ``agent_dir``,
                ``carrying`` and ``grid`` (the format produced by
                ``save_state``).

        Raises:
            ValueError: If any of the required keys is missing.
        """
        required_keys = ("agent_pos", "agent_dir", "carrying", "grid")
        # Validate against the snapshot being loaded (not the env's current
        # ``carrying`` attribute), and raise explicitly so the check is not
        # stripped when Python runs with -O.
        missing = [key for key in required_keys if key not in state_dict]
        if missing:
            raise ValueError(f"state_dict is missing required keys: {missing}")
        # TODO: validate the values as well, not only the presence of keys
        self.state_dict = state_dict

    def reset(
            self,
            *,
            seed: Optional[int] = None,
            options: Optional[dict[str, Any]] = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Reset the environment, honouring the stored seed and snapshot.

        Args:
            seed: Seed forwarded to the base ``reset``. When ``None``, the
                seed given to the constructor (``self.seed``) is used, so an
                environment built with an int seed regenerates the same
                layout on every reset, including objects placed through the
                base class's random generator.
            options: Forwarded unchanged to the base ``reset``.

        Returns:
            ``(observation, info)``. If a snapshot is stored, the carried
            object from the snapshot is restored and the observation reflects
            it.
        """
        if seed is None:
            # Constructor-level seed acts as the default for every reset
            seed = self.seed

        obs, info = super().reset(seed=seed, options=options)

        if self.state_dict is not None:
            # The base reset clears ``carrying``; put the saved item back
            self.carrying = self.state_dict["carrying"]
            if self.carrying is not None:
                # The observation above was built before the carried object
                # was restored; rebuild it so it matches the actual state.
                obs = self.gen_obs()
        return obs, info

    @staticmethod
    def _gen_mission():
        """Return the default mission text for the goal-reaching task."""
        mission_text = "get to the green goal square"
        return mission_text

    # MiniGridEnv._gen_grid
    def _gen_grid(self, width, height):
        """Build the grid for a new episode (concrete ``MiniGridEnv._gen_grid``).

        If a snapshot is stored in ``self.state_dict`` the agent position,
        heading and grid are taken from it and nothing is generated.
        Otherwise a walled room is created, then the goal, the agent, the
        wall obstacles and the lava cells are placed in that order.

        Position specs (``obj_pos`` / ``agent_start_pos``) are either a fixed
        ``(x, y)`` (negative values count back from ``width`` / ``height``)
        or ``((x_min, x_max), (y_min, y_max))`` ranges sampled with inclusive
        bounds. ``None`` means random placement on an empty cell.

        Note: when ``self.seed`` is set, the global ``random`` module is
        re-seeded here.

        Args:
            width: Grid width in cells.
            height: Grid height in cells.
        """
        self.mission = self.mission_name

        if self.state_dict is not None:
            # Restore the saved state instead of generating from parameters
            self.agent_pos = self.state_dict['agent_pos']
            self.agent_dir = self.state_dict['agent_dir']
            self.grid = self.state_dict['grid']
            return

        # Generate from parameters
        if self.seed is not None:
            random.seed(self.seed)

        def resolve_pos(spec):
            """Turn a position spec into a concrete ``(x, y)`` tuple."""
            if isinstance(spec[0], Iterable):
                # Range spec: sample x first, then y (bounds inclusive)
                return (random.randint(spec[0][0], spec[0][1]),
                        random.randint(spec[1][0], spec[1][1]))
            # Fixed spec: negative coordinates are relative to the far edge
            x = spec[0] if spec[0] >= 0 else width + spec[0]
            y = spec[1] if spec[1] >= 0 else height + spec[1]
            return (x, y)

        # Empty grid enclosed by an outer wall
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # Goal
        if self.obj_pos is not None:
            goal_x, goal_y = resolve_pos(self.obj_pos)
            self.put_obj(Goal(), goal_x, goal_y)
        else:
            self.place_obj(Goal())  # random empty cell

        # Agent position and heading
        if self.agent_start_pos is not None:
            # Always store a tuple: the base class compares agent_pos to tuples
            self.agent_pos = resolve_pos(self.agent_start_pos)
            if self.agent_start_dir is not None:
                self.agent_dir = self.agent_start_dir
            else:
                # No heading requested: draw one, otherwise agent_dir would be
                # left at whatever value it had before generation.
                self.agent_dir = self._rand_int(0, 4)
        else:
            self.place_agent()  # random empty cell and random heading
            if self.agent_start_dir is not None:
                self.agent_dir = self.agent_start_dir

        # Wall obstacles
        self.walls = []
        for _ in range(self.n_walls):
            wall = Wall()
            self.walls.append(wall)
            self.place_obj(wall, max_tries=100)

        # Lava cells
        self.lavas = []
        for _ in range(self.n_lavas):
            lava = Lava()
            self.lavas.append(lava)
            self.place_obj(lava, max_tries=100)


    # def get_full_render(self, highlight, tile_size):
    #     """
    #     Render a non-partial observation for visualization
    #     """
    #     # Compute which cells are visible to the agent
    #     _, vis_mask = self.gen_obs_grid()
    #
    #     # Compute the world coordinates of the bottom-left corner
    #     # of the agent's view area
    #     f_vec = self.dir_vec
    #     r_vec = self.right_vec
    #     top_left = (
    #             self.agent_pos
    #             + f_vec * (self.agent_view_size - 1)
    #             - r_vec * (self.agent_view_size // 2)
    #     )
    #     # Render the whole grid
    #     img = self.grid.render(
    #         tile_size,
    #         self.agent_pos,
    #         self.agent_dir,
    #         highlight_mask=None,
    #     )
    #
    #     return img

    def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Apply one action and return the resulting transition.

        Which actions are accepted depends on ``self.action_space.n``:
        turn/forward are always available, sideways moves need ``n > 3``,
        diagonal moves need ``n > 5`` and pickup/drop/toggle need ``n > 7``.

        Movement rules: the agent enters the target cell if it is empty or
        can be overlapped. Entering a goal ends the episode with
        ``self._reward()``; entering lava yields a reward of -0.2 without
        ending the episode. Any other occupied cell blocks the move.

        Args:
            action: One of ``EvaluationActions``.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``truncated``
            is set once ``step_count`` reaches ``max_steps`` and ``info`` is
            an empty dict.

        Raises:
            ValueError: If the action is not recognised or not enabled for
                the current action-space size.
        """
        self.step_count += 1

        reward = 0
        terminated = False

        # Cell directly ahead; needed by forward, pickup, drop and toggle
        fwd_pos = self.front_pos
        fwd_cell = self.grid.get(*fwd_pos)

        acts = self.actions
        n_actions = self.action_space.n

        def move_to(target_pos):
            """Try to move the agent; return ``(terminated, reward)``."""
            cell = self.grid.get(*target_pos)
            if cell is None or cell.can_overlap():
                self.agent_pos = tuple(target_pos)
            if cell is not None:
                if cell.type == "goal":
                    return True, self._reward()
                if cell.type == "lava":
                    # Lava is a penalty, not a terminal state
                    self.agent_pos = tuple(target_pos)
                    return False, -0.2
            # Empty cell, or blocked by an obstacle/wall
            return False, 0

        if action == acts.left:
            self.agent_dir = (self.agent_dir - 1) % 4
        elif action == acts.right:
            self.agent_dir = (self.agent_dir + 1) % 4
        elif action == acts.forward:
            terminated, reward = move_to(fwd_pos)

        # Sideways translation
        elif n_actions > 3 and action == acts.move_left:
            terminated, reward = move_to(self.left_pos())
        elif n_actions > 3 and action == acts.move_right:
            terminated, reward = move_to(self.right_pos())

        # Diagonal movement
        elif n_actions > 5 and action == acts.move_forward_left:
            terminated, reward = move_to(self.forward_left_pos())
        elif n_actions > 5 and action == acts.move_forward_right:
            terminated, reward = move_to(self.forward_right_pos())

        # Pick up the object ahead, if any and if nothing is carried yet
        elif n_actions > 7 and action == acts.pickup:
            if fwd_cell and fwd_cell.can_pickup():
                if self.carrying is None:
                    self.carrying = fwd_cell
                    self.carrying.cur_pos = np.array([-1, -1])
                    self.grid.set(fwd_pos[0], fwd_pos[1], None)

        # Drop the carried object into the empty cell ahead
        elif n_actions > 7 and action == acts.drop:
            if not fwd_cell and self.carrying:
                self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying)
                self.carrying.cur_pos = fwd_pos
                self.carrying = None

        # Toggle/activate the object ahead
        elif n_actions > 7 and action == acts.toggle:
            if fwd_cell:
                fwd_cell.toggle(self, fwd_pos)

        else:
            # Covers both unknown ids and actions disabled for this
            # action-space size (never a silent no-op).
            raise ValueError(f"Unknown action: {action}")

        truncated = self.step_count >= self.max_steps

        if self.render_mode == "human":
            self.render()

        obs = self.gen_obs()

        return obs, reward, terminated, truncated, {}

    @property
    def left_vec(self):
        """
        Vector pointing to the agent's left: ``dir_vec`` ``(dx, dy)``
        mapped to ``(dy, -dx)``.
        """
        heading = self.dir_vec
        step_x, step_y = heading
        return np.array((step_y, -step_x))

    def left_pos(self):
        """
        Position of the cell immediately to the agent's left.
        """
        offset = self.left_vec
        return self.agent_pos + offset

    def right_pos(self):
        """
        Position of the cell immediately to the agent's right.
        """
        offset = self.right_vec
        return self.agent_pos + offset

    def forward_left_pos(self):
        """
        Position of the cell diagonally ahead of the agent on its left.
        """
        beside = self.agent_pos + self.left_vec  # one step to the left
        return beside + self.dir_vec  # then one step forward

    def forward_right_pos(self):
        """
        Position of the cell diagonally ahead of the agent on its right.
        """
        beside = self.agent_pos + self.right_vec  # one step to the right
        return beside + self.dir_vec  # then one step forward

